import inspect
import math
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

# Make sibling modules importable regardless of the current working directory.
# __file__ is the idiomatic way to locate this module (no inspect gymnastics).
currentdir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, currentdir)

"""
Implementation of Resnet+LSTM
"""
class Res_LSTM(nn.Module):
    """ResNet feature extractor followed by an LSTM for clip classification.

    Each frame of the input clip is passed through a ResNet backbone with its
    final fully-connected layer removed; the resulting per-frame feature
    vectors form a sequence fed to an LSTM, and the output of the last time
    step is projected to class scores.
    """

    def __init__(self, sample_size=256, sample_duration=16, num_classes=100,
                lstm_hidden_size=512, lstm_num_layers=1, arch="resnet18"):
        """
        Args:
            sample_size: expected spatial size (h == w) of input frames.
            sample_duration: number of frames per clip (the t dimension).
            num_classes: number of output classes.
            lstm_hidden_size: hidden-state size of the LSTM.
            lstm_num_layers: number of stacked LSTM layers.
            arch: ResNet backbone; one of "resnet18", "resnet34",
                "resnet50", "resnet101".

        Raises:
            ValueError: if ``arch`` is not a supported backbone name.
        """
        super(Res_LSTM, self).__init__()
        self.sample_size = sample_size
        self.sample_duration = sample_duration
        self.num_classes = num_classes

        # network params
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_num_layers = lstm_num_layers

        # network architecture
        if arch == "resnet18":
            resnet = models.resnet18(pretrained=True)
        elif arch == "resnet34":
            resnet = models.resnet34(pretrained=True)
        elif arch == "resnet50":
            resnet = models.resnet50(pretrained=True)
        elif arch == "resnet101":
            resnet = models.resnet101(pretrained=True)
        else:
            # Previously an unknown arch fell through and crashed later with
            # UnboundLocalError on `resnet`; fail fast with a clear message.
            raise ValueError("unsupported arch: {!r}".format(arch))

        # Drop the final fully-connected layer of the ResNet; keep everything
        # up to and including the global average pool.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.lstm = nn.LSTM(
            input_size=resnet.fc.in_features,  # LSTM input dim = ResNet fc input dim
            hidden_size=self.lstm_hidden_size,
            num_layers=self.lstm_num_layers,
            batch_first=True,)

        self.fc1 = nn.Linear(self.lstm_hidden_size, self.num_classes)

    def forward(self, x):
        """Classify a clip.

        Args:
            x: tensor of shape (batch_size, channels, t, h, w).

        Returns:
            Class scores of shape (batch_size, num_classes).
        """
        # CNN: extract one feature vector per frame.
        cnn_embed_seq = []
        for t in range(x.size(2)):
            out = self.resnet(x[:, :, t, :, :])  # (batch_size, feat, 1, 1)
            out = out.view(out.size(0), -1)      # (batch_size, feat)
            cnn_embed_seq.append(out)

        # Stack frames along dim 1 so the sequence is already batch-first:
        # (batch_size, t, feat) — equivalent to stacking on dim 0 then
        # transposing, with one fewer op.
        cnn_embed_seq = torch.stack(cnn_embed_seq, dim=1)

        # LSTM (flatten_parameters enables the fused/fast cuDNN code path).
        self.lstm.flatten_parameters()
        out, (h_n, c_n) = self.lstm(cnn_embed_seq, None)  # out: (batch_size, t, hidden)
        # MLP: the last time step's output is the clip representation.
        out = self.fc1(out[:, -1, :])  # (batch_size, num_classes)

        return out


# Test
# Smoke test: load a checkpoint and run a single batch through the network.
if __name__ == '__main__':
    import sys
    sys.path.append("..")
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    from dataset import CSL_Isolated

    # Device setting
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device: ", device)

    sample_size = 256
    sample_duration = 16
    num_classes = 2
    batch_size = 4
    transform = transforms.Compose([transforms.Resize([sample_size, sample_size]), transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])])
    # Raw strings avoid invalid "\c"-style escape sequences in Windows paths
    # (a DeprecationWarning today, a SyntaxError in future Python).
    dataset = CSL_Isolated(data_path=r"D:\Works\Hisi_codes\SLRDataset_isolate\color_img",
        label_path=r"D:\Works\Hisi_codes\SLRDataset_isolate\gloss_label_2.txt", frames=sample_duration,
        num_classes=num_classes, transform=transform)

    # NOTE(review): the model is built with 101 classes while the dataset above
    # uses 2 — presumably the checkpoint was trained on 101 classes; confirm.
    Res_lstm_net = Res_LSTM(sample_size=sample_size, sample_duration=sample_duration, num_classes=101, arch="resnet50").to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine,
    # matching the CPU fallback chosen above.
    Res_lstm_net.load_state_dict(torch.load(r'D:\Works\Hisi_codes\SLR-master\res50lstm_models\res50lstm_epoch026.pth', map_location=device))
    Res_lstm_net.eval()  # inference mode: fixes batch-norm/dropout behavior

    # Use a DataLoader and run exactly one batch.
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
    with torch.no_grad():  # no gradients needed for a forward-only smoke test
        for batch_idx, data in enumerate(train_loader):
            # get the inputs and labels
            inputs, labels = data['data'].to(device), data['label'].to(device)

            outputs = Res_lstm_net(inputs)
            if isinstance(outputs, list):
                outputs = outputs[0]
            print(outputs)
            break  # one batch is enough for the smoke test

    print()
